On-Time Flight Performance with Spark and Cosmos DB (Seattle)

On-Time Flight Performance Background

This notebook provides an analysis of On-Time Flight Performance and Departure Delays data using GraphFrames for Apache Spark.

Source Data:

References:

Spark to Cosmos DB Connector

Connecting Apache Spark to Azure Cosmos DB accelerates your ability to solve fast-moving data science problems, where data can be quickly persisted and retrieved using Azure Cosmos DB's DocumentDB API. With the Spark to Cosmos DB connector, you can more easily address scenarios including (but not limited to) blazing-fast IoT workloads, updateable columns when performing analytics, push-down predicate filtering, and advanced analytics and data science against fast-changing data in a geo-replicated managed document store with guaranteed SLAs for consistency, availability, low latency, and throughput.

The Spark to Cosmos DB connector is built on the Azure DocumentDB Java SDK. The data flow is as follows:

  1. A connection is made from the Spark master node to the Cosmos DB gateway node to obtain the partition map. Note that the user only specifies the Spark and Cosmos DB connections; the fact that the connection goes to the respective master and gateway nodes is transparent to the user.
  2. This information is returned to the Spark master node. At this point, the query can be parsed to determine which partitions (and their locations) within Cosmos DB need to be accessed.
  3. This information is transmitted to the Spark worker nodes.
  4. The Spark worker nodes then connect directly to the Cosmos DB partitions to extract the needed data and bring it back to the Spark partitions within the Spark worker nodes.
In [1]:
%%configure
{ "name":"Spark-to-Cosmos_DB_Connector", 
  "executorMemory": "8G", 
  "executorCores": 2, 
  "numExecutors": 10,
  "driverMemory" : "2G",
  "jars": ["wasb:///example/jars/0.0.4/azure-cosmosdb-spark_2.1.0_2.11-0.0.4.jar", "wasb:///example/jars/0.0.4/azure-documentdb-1.13.0.jar", "wasb:///example/jars/0.0.4/azure-documentdb-rx-0.9.0-rc2.jar", "wasb:///example/jars/0.0.4/json-20140107.jar", "wasb:///example/jars/0.0.4/rxjava-1.3.0.jar", "wasb:///example/jars/0.0.4/rxnetty-0.4.20.jar"],  
  "conf": {
    "spark.jars.packages": "graphframes:graphframes:0.5.0-spark2.1-s_2.11",   
    "spark.jars.excludes": "org.scala-lang:scala-reflect"
   }
}
Current session configs: {u'kind': 'pyspark', u'name': u'Spark-to-Cosmos_DB_Connector', u'driverMemory': u'2G', u'numExecutors': 10, u'conf': {u'spark.jars.packages': u'graphframes:graphframes:0.5.0-spark2.1-s_2.11', u'spark.jars.excludes': u'org.scala-lang:scala-reflect'}, u'executorCores': 2, u'jars': [u'wasb:///example/jars/0.0.4/azure-cosmosdb-spark_2.1.0_2.11-0.0.4.jar', u'wasb:///example/jars/0.0.4/azure-documentdb-1.13.0.jar', u'wasb:///example/jars/0.0.4/azure-documentdb-rx-0.9.0-rc2.jar', u'wasb:///example/jars/0.0.4/json-20140107.jar', u'wasb:///example/jars/0.0.4/rxjava-1.3.0.jar', u'wasb:///example/jars/0.0.4/rxnetty-0.4.20.jar'], u'executorMemory': u'8G'}
ID | YARN Application ID | Kind | State | Spark UI | Driver log | Current session?
41 | application_1505703138288_0088 | pyspark | idle | Link | Link |
43 | application_1505703138288_0090 | spark | idle | Link | Link |
In [2]:
# Connection
flightsConfig = {
"Endpoint" : "https://pass-cosmosdb.documents.azure.com:443/",
"Masterkey" : "RMZXBrmg60L6tlwx52stSzt4r97WkZ2BXrI7rjuhRvZ5hiUcjiH6zmUoWeHLEqPhtv6mQE3gi8tqFOxfn97kMQ==",
"Database" : "flights",
"preferredRegions" : "Central US",
"Collection" : "departuredelays", 
"SamplingRatio" : "1.0",
"schema_samplesize" : "1000",
"query_pagesize" : "200000",
"query_custom" : "SELECT c.date, c.delay, c.distance, c.origin, c.destination FROM c"
}
Starting Spark application
ID | YARN Application ID | Kind | State | Spark UI | Driver log | Current session?
45 | application_1505703138288_0094 | pyspark | idle | Link | Link |
SparkSession available as 'spark'.
In [3]:
flights = spark.read.format("com.microsoft.azure.cosmosdb.spark").options(**flightsConfig).load()
flights.count()
flights.cache()
DataFrame[origin: string, delay: int, distance: int, destination: string, date: int]
In [4]:
flights.createOrReplaceTempView("flights")

Obtaining airport code information

In [5]:
# Set File Paths
airportsnaFilePath = "wasb://data@doctorwhostore.blob.core.windows.net/airport-codes-na.txt"

# Obtain airports dataset
airportsna = spark.read.csv(airportsnaFilePath, header='true', inferSchema='true', sep='\t')
airportsna.createOrReplaceTempView("airports")

Flights departing from Seattle

In [6]:
%%sql
select count(1) from flights where origin = 'SEA'
count(1)
0 23078

Top 10 Delayed Destinations originating from Seattle

In [7]:
%%sql
select concat(concat((dense_rank() OVER (PARTITION BY 1 ORDER BY TotalDelays DESC)-1), '. '), destination) as destination, TotalDelays
from (
select a.city as destination, sum(f.delay) as TotalDelays, count(1) as Trips
from flights f
join airports a
  on a.IATA = f.destination
where f.origin = 'SEA'
and f.delay > 0
group by a.city 
order by sum(delay) desc limit 10
) a
destination TotalDelays
0 0. San Francisco 26026
1 1. Denver 16058
2 2. Los Angeles 14038
3 3. Chicago 13738
4 4. Las Vegas 11718
5 5. Phoenix 11388
6 6. Oakland 10085
7 7. Dallas 8917
8 8. Anchorage 8255
9 9. Salt Lake City 7867
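The `dense_rank()` window function above simply prepends a 0-based rank to each row of the sorted subquery. The same ranking can be sketched in plain Python (hypothetical helper name and toy totals, for illustration):

```python
# Rank destinations by total delay, descending, mimicking
# dense_rank() OVER (ORDER BY TotalDelays DESC) - 1
def rank_destinations(totals):
    """totals: dict of city -> total delay minutes."""
    ordered = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
    return ["%d. %s" % (i, city) for i, (city, _) in enumerate(ordered)]

top = rank_destinations({"San Francisco": 26026, "Denver": 16058, "Los Angeles": 14038})
# top[0] == "0. San Francisco"
```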

Calculate median delays by destination cities departing from Seattle

In [8]:
%%sql
select a.city as destination, percentile_approx(f.delay, 0.5) as median_delay
from flights f
join airports a
  on a.IATA = f.destination
where f.origin = 'SEA'
group by a.city 
order by percentile_approx(f.delay, 0.5)
destination median_delay
0 Jackson Hole -6
1 Fresno -5
2 Long Beach -5
3 Omaha -5
4 Portland -4
5 Orange County -4
6 Santa Barbara -4
7 Burbank -4
8 San Diego -3
9 Ontario -3
10 Colorado Springs -3
11 Lihue, Kauai -3
12 Palm Springs -3
13 Kansas City -3
14 Kahului, Maui -3
15 Phoenix -2
16 Los Angeles -2
17 St. Louis -2
18 San Jose -2
19 San Antonio -2
20 Las Vegas -2
21 Honolulu, Oahu -2
22 Ketchikan -2
23 Washington DC -2
24 Fairbanks -2
25 Austin -2
26 Minneapolis -2
27 Tucson -1
28 Salt Lake City -1
29 Anchorage -1
30 Detroit -1
31 Fort Lauderdale -1
32 Miami -1
33 Spokane -1
34 Orlando -1
35 Philadelphia -1
36 Sacramento -1
37 Juneau -1
38 Charlotte -1
39 Boston -1
40 San Francisco -1
41 New York -1
42 Houston 0
43 Atlanta 0
44 Oakland 0
45 Newark 0
46 Milwaukee 0
47 Denver 0
48 Dallas 0
49 Chicago 0
50 Cincinnati 1
51 Reno 1
52 Albuquerque 5
53 Cleveland 12
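`percentile_approx(f.delay, 0.5)` returns an approximate median, which can be computed in a single pass over distributed data. On a small sample, the exact equivalent is easy to sketch in plain Python (hypothetical delay values, for illustration):

```python
import statistics

# Exact median of delays per destination; percentile_approx trades
# exactness for efficiency on large, partitioned datasets.
delays_by_city = {
    "Portland": [-6, -4, -4, -2, 10],
    "Denver": [-3, 0, 0, 2, 45],
}
medians = {city: statistics.median(d) for city, d in delays_by_city.items()}
```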

Building a GraphFrame

Using GraphFrames for Apache Spark to run degree and motif queries against Cosmos DB

In [9]:
# Build `departureDelays` DataFrame
departureDelays = spark.sql("select cast(f.date as int) as tripid, cast(concat(concat(concat(concat(concat(concat('2014-', concat(concat(substr(cast(f.date as string), 1, 2), '-')), substr(cast(f.date as string), 3, 2)), ' '), substr(cast(f.date as string), 5, 2)), ':'), substr(cast(f.date as string), 7, 2)), ':00') as timestamp) as `localdate`, cast(f.delay as int), cast(f.distance as int), f.origin as src, f.destination as dst, o.city as city_src, d.city as city_dst, o.state as state_src, d.state as state_dst from flights f join airports o on o.iata = f.origin join airports d on d.iata = f.destination") 

# Create Temporary View and cache
departureDelays.createOrReplaceTempView("departureDelays")
departureDelays.cache()
DataFrame[tripid: int, localdate: timestamp, delay: int, distance: int, src: string, dst: string, city_src: string, city_dst: string, state_src: string, state_dst: string]
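The nested concat/substr expression above rebuilds a timestamp from the integer date field, which encodes MMDDHHMM (the month's leading zero is lost once the value is stored as an int). A plain-Python sketch of the intended conversion (hypothetical helper name, for illustration):

```python
def trip_date_to_local(date_int, year=2014):
    """Decode an MMDDHHMM integer such as 1011855 into an ISO-style timestamp."""
    s = str(date_int).zfill(8)          # restore the month's leading zero
    mm, dd, hh, mi = s[0:2], s[2:4], s[4:6], s[6:8]
    return "%d-%s-%s %s:%s:00" % (year, mm, dd, hh, mi)

# trip_date_to_local(1011855) -> "2014-01-01 18:55:00"
```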
In [10]:
# Note, ensure you have already installed the GraphFrames spark-package
import os
sc.addPyFile(os.path.expanduser('./graphframes_graphframes-0.5.0-spark2.1-s_2.11.jar'))
from pyspark.sql.functions import *
from graphframes import *

# Create Vertices (airports) and Edges (flights)
tripVertices = airportsna.withColumnRenamed("IATA", "id").distinct()
tripEdges = departureDelays.select("tripid", "delay", "src", "dst", "city_dst", "state_dst")

# Cache Vertices and Edges
tripEdges.cache()
tripVertices.cache()

# Create TripGraph
tripGraph = GraphFrame(tripVertices, tripEdges)

Which flights departing SEA have the most significant average delays?

Note, the joins are there to show the city name instead of the IATA codes. The dense_rank() code helps order the data correctly when viewed in Jupyter notebooks.

In [11]:
flightDelays = tripGraph.edges.filter("src = 'SEA' and delay > 0").groupBy("src", "dst").avg("delay").sort(desc("avg(delay)"))
flightDelays.createOrReplaceTempView("flightDelays")
In [12]:
%%sql
select concat(concat((dense_rank() OVER (PARTITION BY 1 ORDER BY avg_delay DESC)-1), '. '), city) as destination, 
avg_delay
from (
select a.city, `avg(delay)` as avg_delay 
from flightDelays f
join airports a
on f.dst = a.iata
order by `avg(delay)` 
desc limit 10
) s
destination avg_delay
0 0. Philadelphia 55.666667
1 1. Colorado Springs 43.538462
2 2. Fresno 43.038462
3 3. Long Beach 39.397059
4 4. Washington DC 37.733333
5 5. Miami 37.325581
6 6. San Francisco 36.502104
7 7. Santa Barbara 36.482759
8 8. New York 35.031250
9 9. Chicago 33.603352

Which is the most important airport (in terms of connections)

It would take a relatively complicated SQL statement to calculate all of the edges to a single vertex, grouped by the vertices. Instead, we can use the graph degree method.
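The `degrees` property counts, for each vertex, the number of edges that touch it (in-degree plus out-degree). A minimal plain-Python sketch of the same idea over a toy edge list:

```python
from collections import Counter

def degrees(edges):
    """edges: list of (src, dst) pairs; returns vertex -> degree."""
    deg = Counter()
    for src, dst in edges:
        deg[src] += 1   # out-degree contribution
        deg[dst] += 1   # in-degree contribution
    return deg

deg = degrees([("SEA", "SFO"), ("SEA", "LAX"), ("SFO", "SEA")])
# deg["SEA"] == 3
```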

In [13]:
airportConnections = tripGraph.degrees.sort(desc("degree"))
airportConnections.createOrReplaceTempView("airportConnections")
In [14]:
%%sql
select concat(concat((dense_rank() OVER (PARTITION BY 1 ORDER BY degree DESC)-1), '. '), city) as destination, 
degree
from (
select a.city, f.degree 
from airportConnections f 
join airports a
  on a.iata = f.id
order by f.degree desc 
limit 10
) a
destination degree
0 0. Atlanta 179774
1 1. Dallas 133966
2 2. Chicago 125405
3 3. Los Angeles 106853
4 4. Denver 103699
5 5. Houston 85685
6 6. Phoenix 79672
7 7. San Francisco 77635
8 8. Las Vegas 66101
9 9. Charlotte 56103

Are there direct flights between Seattle and San Jose?

In [15]:
filteredPaths = tripGraph.bfs(
    fromExpr = "id = 'SEA'",
    toExpr = "id = 'SJC'",
    maxPathLength = 1)
filteredPaths.show()
+--------------------+--------------------+--------------------+
|                from|                  e0|                  to|
+--------------------+--------------------+--------------------+
|[Seattle,WA,USA,SEA]|[1031855,-1,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1041215,-4,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1041855,-5,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1070710,-3,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1080710,-1,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1110600,-11,SEA,...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1110710,-9,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1150710,-6,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1171600,-2,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1230600,-5,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1282030,-2,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1292030,5,SEA,SJ...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1302030,-2,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1300710,-9,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[1160720,10,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[2022030,-4,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[2031600,-3,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[2031215,-5,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[2050600,-4,SEA,S...|[San Jose,CA,USA,...|
|[Seattle,WA,USA,SEA]|[2071215,-4,SEA,S...|[San Jose,CA,USA,...|
+--------------------+--------------------+--------------------+
only showing top 20 rows

But are there any direct flights between San Jose and Buffalo?

  • Try maxPathLength = 1, which means one edge (i.e. one flight) between SJC and BUF: a direct flight
  • Try maxPathLength = 2, which means two edges between SJC and BUF: all the variations of flights between San Jose and Buffalo with exactly one stopover in between
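The behavior of bfs with maxPathLength can be sketched over a toy adjacency map (hypothetical routes, for illustration): enumerate paths of at most N edges from one vertex to another.

```python
def bfs_paths(adj, src, dst, max_len):
    """Return all paths from src to dst using at most max_len edges."""
    paths, frontier = [], [[src]]
    for _ in range(max_len):
        next_frontier = []
        for path in frontier:
            for nxt in adj.get(path[-1], []):
                if nxt == dst:
                    paths.append(path + [nxt])
                else:
                    next_frontier.append(path + [nxt])
        frontier = next_frontier
    return paths

routes = {"SJC": ["BOS", "LAX"], "BOS": ["BUF"], "LAX": ["SEA"]}
# No direct SJC -> BUF edge, but one 2-hop path via BOS:
# bfs_paths(routes, "SJC", "BUF", 1) == []
# bfs_paths(routes, "SJC", "BUF", 2) == [["SJC", "BOS", "BUF"]]
```

Note this sketch enumerates every qualifying path; GraphFrames' bfs stops at the shortest path length found, which is why maxPathLength = 1 returns an empty result when no direct flight exists.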
In [16]:
filteredPaths = tripGraph.bfs(
  fromExpr = "id = 'SJC'",
  toExpr = "id = 'BUF'",
  maxPathLength = 1)
filteredPaths.show()
+----+-----+-------+---+
|City|State|Country| id|
+----+-----+-------+---+
+----+-----+-------+---+
In [17]:
filteredPaths = tripGraph.bfs(
  fromExpr = "id = 'SJC'",
  toExpr = "id = 'BUF'",
  maxPathLength = 2)
filteredPaths.show()
+--------------------+--------------------+-------------------+--------------------+--------------------+
|                from|                  e0|                 v1|                  e1|                  to|
+--------------------+--------------------+-------------------+--------------------+--------------------+
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[1011059,13,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[1050635,1,BOS,BU...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[1130710,-10,BOS,...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[1181445,78,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[1190710,-7,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[1210710,-2,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[1271445,-4,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[2021730,-10,BOS,...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[2040710,-6,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[2041730,14,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[2100710,-2,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[2140630,-9,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[2191445,-10,BOS,...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[2211445,24,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[2231049,-4,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[2261049,22,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[3060630,-10,BOS,...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[3070630,-9,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[3081455,-6,BOS,B...|[Buffalo,NY,USA,BUF]|
|[San Jose,CA,USA,...|[1022124,0,SJC,BO...|[Boston,MA,USA,BOS]|[3091730,-3,BOS,B...|[Buffalo,NY,USA,BUF]|
+--------------------+--------------------+-------------------+--------------------+--------------------+
only showing top 20 rows

In that case, what is the most common transfer point between San Jose and Buffalo?

In [18]:
commonTransferPoint = filteredPaths.groupBy("v1.id", "v1.City").count().orderBy(desc("count"))
commonTransferPoint.createOrReplaceTempView("commonTransferPoint")
In [19]:
%%sql
select concat(concat((dense_rank() OVER (PARTITION BY 1 ORDER BY Trips DESC)-1), '. '), city) as destination, 
Trips
from (
select City, `count` as Trips from commonTransferPoint order by Trips desc limit 10
) a
destination Trips
0 0. Las Vegas 107442
1 1. Chicago 87696
2 2. Phoenix 76770
3 3. New York 31968
4 4. Atlanta 28910
5 5. Chicago 20060
6 6. Boston 1488
7 7. Minneapolis 164

Predicting Flight Delays

Extending the analysis we have done up to this point, can we also predict whether a flight will be delayed, on-time, or early based on the available data?

Prepare the Dataset

The first thing we will do is cleanse the data and apply labels to it (on-time or delayed). We will also remove any rows with NULL values.
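The labeling rule used below is simply delay <= 0 means on-time, otherwise delayed. A one-line plain-Python equivalent (hypothetical helper name, for illustration):

```python
def flight_status(delay):
    """Label a flight by its departure delay in minutes."""
    return "on-time" if delay <= 0 else "delayed"
```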

In [20]:
# This contains a generated mapping between tripid and airline
#   You can get the file at https://github.com/dennyglee/databricks/blob/master/misc/trip_airline_map.csv
#   For this example, the trip_airline_map.csv file has been uploaded to my mounted storage account.
tripAirlineMapFilePath = "wasb://data@doctorwhostore.blob.core.windows.net/trip_airline_map.csv"
tripAirlineMap = spark.read.csv(tripAirlineMapFilePath, sep=",", header=True)
tripAirlineMap.createOrReplaceTempView("tripAirlineMap")
In [21]:
# Prep dataset
# Including only Seattle and Las Vegas for this demo
flightML = spark.sql("select cast(distance as double) as distance, src as origin, state_src as origin_state, dst as destination, state_dst as destination_state, concat(concat(concat(cast(tripid as string), src), dst), cast((delay + 2000) as string)) as trip_identifier, case when delay <= 0 then 'on-time' else 'delayed' end as flight_status from departureDelays where src IN ('LAS', 'SEA')")
flightML = flightML.dropna().dropDuplicates()
flightML.createOrReplaceTempView("flightML")
In [22]:
# Join flights and airline information
dataset = spark.sql("select f.distance, f.origin, f.origin_state, f.destination, f.destination_state, f.trip_identifier, f.flight_status, m.airline from flightML f join tripAirlineMap m on m.trip_identifier = f.trip_identifier")
dataset = dataset.dropDuplicates()
#dataset = flightML
cols = dataset.columns
In [23]:
dataset.printSchema()
root
 |-- distance: double (nullable = true)
 |-- origin: string (nullable = true)
 |-- origin_state: string (nullable = true)
 |-- destination: string (nullable = true)
 |-- destination_state: string (nullable = true)
 |-- trip_identifier: string (nullable = true)
 |-- flight_status: string (nullable = false)
 |-- airline: string (nullable = true)

Building ML Pipeline

Before we can run our various models against this data, we first need to vectorize it using a One-Hot Encoder (for categorical data), a String Indexer (to create an index from our labelled values), and a Vector Assembler.
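StringIndexer assigns each category an index ordered by descending frequency, and OneHotEncoder then turns that index into a sparse binary vector. A plain-Python sketch of both steps (toy data and hypothetical helper names, for illustration):

```python
from collections import Counter

def string_index(values):
    """Index categories by descending frequency, as StringIndexer does
    (ties broken alphabetically here for determinism)."""
    freq = Counter(values)
    ordered = sorted(freq, key=lambda v: (-freq[v], v))
    return {v: i for i, v in enumerate(ordered)}

def one_hot(index, size):
    """Dense stand-in for the encoder's binary SparseVector."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec

idx = string_index(["Delta", "Alaska", "Delta", "United"])
# idx["Delta"] == 0 (most frequent category gets index 0)
```

Note that Spark's OneHotEncoder drops the last category by default (dropLast=True), so its vectors have one fewer slot than this sketch shows.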

In [24]:
# One-Hot Encoding
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler

categoricalColumns = ["origin", "origin_state", "destination", "destination_state", "trip_identifier", "airline"]
stages = [] # stages in our Pipeline
for categoricalCol in categoricalColumns:
  # Category Indexing with StringIndexer
  stringIndexer = StringIndexer(inputCol=categoricalCol, outputCol=categoricalCol+"Index")
  
  # Use OneHotEncoder to convert categorical variables into binary SparseVectors
  encoder = OneHotEncoder(inputCol=categoricalCol+"Index", outputCol=categoricalCol+"classVec")
  
  # Add stages.  These are not run here, but will run all at once later on.
  stages += [stringIndexer, encoder]

# Convert label into label indices using the StringIndexer
label_stringIdx = StringIndexer(inputCol = "flight_status", outputCol = "label")
stages += [label_stringIdx]

# Transform all features into a vector using VectorAssembler
numericCols = ["distance"]
assemblerInputs = [c + "classVec" for c in categoricalColumns] + numericCols
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
stages += [assembler]
In [25]:
# Create a Pipeline.
pipeline = Pipeline(stages=stages)
# Run the feature transformations.
#  - fit() computes feature statistics as needed.
#  - transform() actually transforms the features.
pipelineModel = pipeline.fit(dataset)
dataset = pipelineModel.transform(dataset)

# Keep relevant columns
selectedcols = ["label", "features"] + cols
dataset = dataset.select(selectedcols)
dataset.show()
+-----+--------------------+--------+------+------------+-----------+-----------------+-----------------+-------------+-------+
|label|            features|distance|origin|origin_state|destination|destination_state|  trip_identifier|flight_status|airline|
+-----+--------------------+--------+------+------------+-----------+-----------------+-----------------+-------------+-------+
|  1.0|(55487,[0,1,22,84...|   194.0|   LAS|          NV|        BUR|               CA|1010900LASBUR2008|      delayed| United|
|  1.0|(55487,[4,87,4935...|   890.0|   SEA|          WA|        DEN|               CO|1011744SEADEN2055|      delayed| Virgin|
|  0.0|(55487,[18,89,205...|   753.0|   SEA|          WA|        LAS|               NV|1020900SEALAS1997|      on-time|  Delta|
|  1.0|(55487,[0,1,67,10...|  1760.0|   LAS|          NV|        RDU|               NC|1020940LASRDU2002|      delayed| United|
|  1.0|(55487,[0,1,5,86,...|   222.0|   LAS|          NV|        PHX|               AZ|1021415LASPHX2297|      delayed|  Delta|
|  0.0|(55487,[0,1,2,84,...|   205.0|   LAS|          NV|        LAX|               CA|1021705LASLAX1999|      on-time| Alaska|
|  0.0|(55487,[20,90,495...|  1259.0|   SEA|          WA|        ANC|               AK|1022130SEAANC1996|      on-time|  Delta|
|  0.0|(55487,[18,89,237...|   753.0|   SEA|          WA|        LAS|               NV|1030600SEALAS1994|      on-time|  Delta|
|  1.0|(55487,[20,90,177...|  1259.0|   SEA|          WA|        ANC|               AK|1032140SEAANC2035|      delayed| United|
|  0.0|(55487,[2,84,5377...|   829.0|   SEA|          WA|        LAX|               CA|1040900SEALAX1998|      on-time|  Delta|
|  1.0|(55487,[19,95,154...|   112.0|   SEA|          WA|        PDX|               OR|1051504SEAPDX2007|      delayed| Virgin|
|  0.0|(55487,[0,1,45,10...|  1830.0|   LAS|          NV|        BWI|               MD|1052255LASBWI1996|      on-time| Alaska|
|  1.0|(55487,[0,1,67,10...|  1760.0|   LAS|          NV|        RDU|               NC|1060940LASRDU2041|      delayed| United|
|  1.0|(55487,[11,84,468...|   583.0|   SEA|          WA|        OAK|               CA|1061000SEAOAK2013|      delayed| United|
|  1.0|(55487,[0,1,21,96...|  1129.0|   LAS|          NV|        MSP|               MN|1061555LASMSP2004|      delayed|  Delta|
|  0.0|(55487,[30,97,429...|  2326.0|   SEA|          WA|        HNL|               HI|1071035SEAHNL1997|      on-time| Alaska|
|  1.0|(55487,[5,86,3505...|   962.0|   SEA|          WA|        PHX|               AZ|1071855SEAPHX2016|      delayed|  Delta|
|  1.0|(55487,[0,1,10,93...|  1518.0|   LAS|          NV|        ATL|               GA|1072355LASATL2046|      delayed|  Delta|
|  0.0|(55487,[20,90,252...|  1259.0|   SEA|          WA|        ANC|               AK|1080800SEAANC1997|      on-time|  Delta|
|  1.0|(55487,[11,84,151...|   583.0|   SEA|          WA|        OAK|               CA|1081620SEAOAK2029|      delayed| United|
+-----+--------------------+--------+------+------------+-----------+-----------------+-----------------+-------------+-------+
only showing top 20 rows

Randomly split data into training and test datasets

  • Set the seed for reproducibility
In [26]:
(trainingData, testData) = dataset.randomSplit([0.7, 0.3], seed = 100)

Logistic Regression

Let's try using logistic regression to see if we can accurately predict if a flight will be delayed.

  • First, we will train the data using Logistic Regression
  • Next we will run that model against the testData
In [27]:
from pyspark.ml.classification import LogisticRegression

# Create initial LogisticRegression model
lr = LogisticRegression(labelCol="label", featuresCol="features", maxIter=10)

# Train model with Training Data
lrModel = lr.fit(trainingData)
In [28]:
# Make predictions on test data using the transform() method.
# LogisticRegression.transform() will only use the 'features' column.
predictions = lrModel.transform(testData)

View LR Model's predictions

  • Recall, label is the actual test value, prediction is the predicted value
    • where 0 = on-time and 1 = delayed
In [29]:
selected = predictions.select("label", "prediction", "probability", "flight_status", "destination", "destination_state").where("destination = 'SEA'")
selected.show()
+-----+----------+--------------------+-------------+-----------+-----------------+
|label|prediction|         probability|flight_status|destination|destination_state|
+-----+----------+--------------------+-------------+-----------+-----------------+
|  0.0|       0.0|[0.77185795884037...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.77185795884037...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.77185795884037...|      on-time|        SEA|               WA|
|  1.0|       0.0|[0.77185795884037...|      delayed|        SEA|               WA|
|  0.0|       0.0|[0.98241960097150...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.98241960097150...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.98241960097150...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.77185795884037...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.98241960097150...|      on-time|        SEA|               WA|
|  1.0|       1.0|[0.13666628633447...|      delayed|        SEA|               WA|
|  0.0|       0.0|[0.98241960097150...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.98241960097150...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.77185795884037...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.77185795884037...|      on-time|        SEA|               WA|
|  1.0|       0.0|[0.77185795884037...|      delayed|        SEA|               WA|
|  0.0|       0.0|[0.77185795884037...|      on-time|        SEA|               WA|
|  1.0|       1.0|[0.13666628633447...|      delayed|        SEA|               WA|
|  0.0|       0.0|[0.98241960097150...|      on-time|        SEA|               WA|
|  0.0|       0.0|[0.77185795884037...|      on-time|        SEA|               WA|
|  1.0|       0.0|[0.77185795884037...|      delayed|        SEA|               WA|
+-----+----------+--------------------+-------------+-----------+-----------------+
only showing top 20 rows

Evaluate our model

Let's use the BinaryClassificationEvaluator to evaluate our model; by default it computes the area under the ROC curve (areaUnderROC).
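The areaUnderROC metric can be read as the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one. A plain-Python sketch of that rank statistic on toy scores (hypothetical helper name, for illustration):

```python
def area_under_roc(scores, labels):
    """AUC = probability a random positive outscores a random negative
    (ties count as half a win)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Perfectly separated scores give an AUC of 1.0:
# area_under_roc([0.9, 0.8, 0.3, 0.1], [1, 1, 0, 0]) == 1.0
```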

In [30]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Evaluate model
evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction")
evaluator.evaluate(predictions)
0.7740006537020221
In [31]:
predictions.createOrReplaceTempView("predictions")
In [32]:
%%sql
select * from predictions limit 10
label features distance origin origin_state destination destination_state trip_identifier flight_status airline rawPrediction probability prediction
0 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 205 LAS NV LAX CA 3211705LASLAX1994 on-time Alaska {u'type': 1, u'values': [3.42778343321, -3.427... {u'type': 1, u'values': [0.968561643378, 0.031... 0
1 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 205 LAS NV LAX CA 2191155LASLAX1998 on-time Alaska {u'type': 1, u'values': [3.42778343321, -3.427... {u'type': 1, u'values': [0.968561643378, 0.031... 0
2 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 222 LAS NV PHX AZ 2241605LASPHX2000 on-time Delta {u'type': 1, u'values': [0.0724148894375, -0.0... {u'type': 1, u'values': [0.518095815305, 0.481... 0
3 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 222 LAS NV PHX AZ 3110625LASPHX1997 on-time Alaska {u'type': 1, u'values': [2.87681668946, -2.876... {u'type': 1, u'values': [0.946688431947, 0.053... 0
4 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 1316 LAS NV ORD IL 1280640LASORD1995 on-time Delta {u'type': 1, u'values': [0.332940396795, -0.33... {u'type': 1, u'values': [0.582474648207, 0.417... 0
5 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 225 LAS NV SAN CA 1151535LASSAN1996 on-time Alaska {u'type': 1, u'values': [2.66412407916, -2.664... {u'type': 1, u'values': [0.934876202365, 0.065... 0
6 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 225 LAS NV SAN CA 3281205LASSAN2000 on-time Delta {u'type': 1, u'values': [-0.140277720865, 0.14... {u'type': 1, u'values': [0.464987964396, 0.535... 1
7 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 1518 LAS NV ATL GA 3131105LASATL1994 on-time Alaska {u'type': 1, u'values': [3.21397942262, -3.213... {u'type': 1, u'values': [0.961356971792, 0.038... 0
8 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 1953 LAS NV JFK NY 3262215LASJFK1999 on-time Alaska {u'type': 1, u'values': [3.14994749752, -3.149... {u'type': 1, u'values': [0.958906653006, 0.041... 0
9 0 {u'type': 0, u'size': 55487, u'indices': [0, 1... 1953 LAS NV JFK NY 3210630LASJFK1999 on-time Delta {u'type': 1, u'values': [0.345545697496, -0.34... {u'type': 1, u'values': [0.585537007043, 0.414... 0
In [33]:
%%sql
select confusion, count(1)
from (
select 
case
  when label = 0 and prediction = 0 then 'True Positives (On-Time, Predicted: On-Time)'
  when label = 0 and prediction = 1 then 'False Negatives (On-Time, Predicted: Delayed)'
  when label = 1 and prediction = 0 then 'False Positives (Delayed, Predicted: On-Time)'
  when label = 1 and prediction = 1 then 'True Negatives (Delayed, Predicted: Delayed)'
end as confusion
from predictions
) a
group by confusion
confusion count(1)
0 False Negatives (On-Time, Predicted: Delayed) 2017
1 False Positives (Delayed, Predicted: On-Time) 2805
2 True Negatives (Delayed, Predicted: Delayed) 4579
3 True Positives (On-Time, Predicted: On-Time) 7101
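From the counts above we can derive summary metrics in plain Python (treating on-time as the positive class, as the query does; hypothetical helper name, for illustration):

```python
def summarize(tp, tn, fp, fn):
    """Accuracy, precision, and recall from confusion-matrix counts."""
    total = tp + tn + fp + fn
    return {
        "accuracy": (tp + tn) / float(total),
        "precision": tp / float(tp + fp),
        "recall": tp / float(tp + fn),
    }

metrics = summarize(tp=7101, tn=4579, fp=2805, fn=2017)
# metrics["accuracy"] is about 0.708
```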

Plot decision boundary of the logistic regression model

In [34]:
# Take a sample of 200 to plot decision boundary of logistic regression
predictions_sample = predictions.sample(False, 0.1, seed=100).limit(200)
predictions_sample.createOrReplaceTempView("predictions_sample")

# Build schema for data
from pyspark.sql.types import *
schema = StructType([
        StructField('X1', FloatType(), True), 
        StructField('X2', StringType(), True),
        StructField('y', IntegerType(), True)])

# Extract data (cast the label to int to match the schema)
ps = predictions_sample.rdd.map(lambda p: (float(p.rawPrediction[0]), p.airline, int(p.label)))

# Convert to DataFrame using the schema defined above
psDF = spark.createDataFrame(ps, schema)
psDF.createOrReplaceTempView("psDF")
In [35]:
%%sql -o pyX
select `X1`, case `X2` when 'Alaska' then 1.0 when 'United' then 2.0 when 'Delta' then 3.0 when 'Virgin' then 4.0 end as `X2` from psDF
X1 X2
0 0.192810 3
1 0.195680 3
2 3.137342 1
3 -0.527650 4
4 -0.751916 3
5 -1.597688 2
6 0.332940 3
7 1.470000 3
8 3.232925 1
9 0.623382 3
10 -3.365150 2
11 1.349276 3
12 1.464403 3
13 0.508745 3
14 -2.869282 2
15 0.072415 3
16 -4.436605 2
17 -1.597688 2
18 -2.989676 2
19 0.345546 3
20 2.054653 3
21 3.129571 1
22 3.149947 1
23 -3.849939 2
24 -2.071480 2
25 -1.006322 4
26 -1.712815 2
27 -0.647341 3
28 -3.573319 2
29 -2.451185 2
... ... ...
170 -0.647341 3
171 -4.208075 2
172 -1.592092 2
173 -3.072791 2
174 1.218832 3
175 0.402039 3
176 1.668325 3
177 2.398786 1
178 4.197101 3
179 -2.438710 2
180 -2.652514 2
181 -4.069861 2
182 0.623382 3
183 0.623382 3
184 -4.069861 2
185 1.349276 3
186 1.464403 3
187 1.470000 3
188 3.701857 1
189 0.508745 3
190 1.290436 3
191 1.218832 3
192 1.334464 3
193 0.623382 3
194 -2.553347 2
195 0.274129 3
196 1.455413 3
197 4.227354 1
198 3.994874 1
199 2.418639 3

200 rows × 2 columns

In [36]:
%%sql -o pyY
select 1.*y as `y` from psDF
y
0 0
1 0
2 0
3 0
4 1
5 1
6 1
7 1
8 0
9 0
10 0
11 0
12 0
13 0
14 1
15 1
16 1
17 1
18 1
19 0
20 0
21 0
22 1
23 1
24 0
25 0
26 0
27 0
28 0
29 1
... ...
170 1
171 1
172 0
173 0
174 0
175 0
176 0
177 1
178 1
179 0
180 0
181 0
182 1
183 1
184 1
185 1
186 1
187 1
188 1
189 1
190 0
191 1
192 1
193 0
194 0
195 0
196 0
197 0
198 1
199 1

200 rows × 1 columns

In [37]:
%%local
%matplotlib inline
#
# Reference: https://stackoverflow.com/a/28257799/1100699
#
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt

#X, y = make_classification(200, 2, 2, 0, weights=[.5, .5], random_state=15)
X = np.asarray(pyX)
z = np.asarray(pyY, dtype=np.int)
y = z.ravel() 

clf = LogisticRegression().fit(X[:100], y[:100])

xx, yy = np.mgrid[-5:5:.01, -5:5:.01]
grid = np.c_[xx.ravel(), yy.ravel()]
probs = clf.predict_proba(grid)[:, 1].reshape(xx.shape)

f, ax = plt.subplots(figsize=(8, 6))
contour = ax.contourf(xx, yy, probs, 25, cmap="RdBu",
                      vmin=0, vmax=1)
ax_c = f.colorbar(contour)
ax_c.set_label("$P(y = 1)$")
ax_c.set_ticks([0, .25, .5, .75, 1])

ax.scatter(X[100:,0], X[100:, 1], c=y[100:], s=50,
           cmap="RdBu", vmin=-.2, vmax=1.2,
           edgecolor="white", linewidth=1)

ax.set(aspect="equal",
       xlim=(-5, 5), ylim=(-5, 5),
       xlabel="$X_1$", ylabel="$X_2$")
Out[37]:
[(-5, 5),
 <matplotlib.text.Text at 0x7f1b1cb36a50>,
 (-5, 5),
 <matplotlib.text.Text at 0x7f1b1cc5be10>,
 None]
In [ ]: